import copy
import random
import numpy as np
import os
import logging
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from tqdm import tqdm
from nats_bench import create
import sys
import argparse

__all__ = ["ReinforceFinderNASBench201"]


class ArchManager:
    """Convert between NAS-Bench-201 architecture strings and op-index lists.

    An architecture string lists six edges in three groups separated by
    ``+``. Every edge is written ``<operation>~<source>`` and each group is
    wrapped in ``|`` characters, e.g.
    ``|none~0|+|none~0|none~1|+|none~0|none~1|none~2|``.
    """

    def __init__(self):
        # The position of a name in this list is its integer operation code.
        self.operations = [
            'none',
            'skip_connect',
            'nor_conv_1x1',
            'nor_conv_3x3',
            'avg_pool_3x3',
        ]
        self.num_ops = len(self.operations)
        self.num_edges = 6

    def random_sample(self):
        """Return an architecture string with one random operation per edge."""
        highest_code = self.num_ops - 1
        picks = []
        for _ in range(self.num_edges):
            picks.append(random.randint(0, highest_code))
        return self._ops_to_arch_str(picks)

    def _arch_str_to_ops(self, arch_str):
        """Parse an architecture string into a list of operation codes.

        The first ``+``-separated group contributes a single edge; the second
        and third groups contribute one code per ``|``-separated edge.

        Raises:
            ValueError: if an operation name is not in ``self.operations``.
            IndexError: if the string has fewer than three groups.
        """
        code_of = self.operations.index
        groups = arch_str.split('+')

        # First group: a single edge, so no further splitting on '|'.
        codes = [code_of(groups[0].strip('|').split('~')[0])]

        # Remaining two groups: several edges each.
        for position in (1, 2):
            for token in groups[position].strip('|').split('|'):
                codes.append(code_of(token.split('~')[0]))

        return codes

    def _ops_to_arch_str(self, ops):
        """Render six operation codes as an architecture string.

        Raises:
            ValueError: if ``ops`` does not hold exactly ``num_edges`` entries.
        """
        if len(ops) != self.num_edges:
            raise ValueError(f"Expected {self.num_edges} operations, got {len(ops)}")

        # Source-node suffix of each edge, in the order the codes are given.
        sources = (0, 0, 1, 0, 1, 2)
        tokens = [
            f"{self.operations[ops[slot]]}~{src}"
            for slot, src in enumerate(sources)
        ]
        return "|{}|+|{}|{}|+|{}|{}|{}|".format(*tokens)


class AccuracyPredictor:
    """Look up architecture accuracies through the NATS-Bench ``tss`` API."""

    def __init__(self, dataset='cifar10', metric='test'):
        """Open the benchmark API and store which dataset/metric to report.

        If anything in the setup raises, the error is printed and the process
        exits with status 1.

        Args:
            dataset: dataset name passed to the benchmark queries.
            metric: ``'valid'`` to report validation accuracy; any other
                value reports test accuracy.
        """
        try:
            # Load the NATS-Bench API for architecture evaluation
            self.api = create(None, 'tss', fast_mode=True, verbose=False)
            self.dataset, self.metric = dataset, metric

            print(f"Successfully loaded NATS-Bench TSS API for {dataset} with {metric} metric")
        except Exception as e:
            print(f"Unable to load NATS-Bench TSS API: {e}")
            print("Please ensure you have downloaded the NATS-Bench TSS data file and set the correct path")
            sys.exit(1)

    def predict_accuracy(self, arch):
        """Return the benchmark accuracy recorded for architecture ``arch``.

        Args:
            arch: architecture string understood by ``query_index_by_arch``.

        Returns:
            The ``'valid-accuracy'`` entry when the metric is ``'valid'``,
            otherwise the ``'test-accuracy'`` entry. If the lookup raises,
            the error is printed and ``0.0`` is returned instead.
        """
        try:
            arch_index = self.api.query_index_by_arch(arch)

            wants_valid = self.metric == 'valid'
            # CIFAR-10 validation results are queried under their own name.
            query_name = self.dataset
            if wants_valid and query_name == 'cifar10':
                query_name = 'cifar10-valid'

            info = self.api.get_more_info(arch_index, query_name, hp='200', is_random=False)
            result_key = 'valid-accuracy' if wants_valid else 'test-accuracy'
            return info[result_key]
        except Exception as e:
            print(f"Error querying architecture performance: {e}")
            return 0.0


class ReinforcePolicy(nn.Module):
    """Independent categorical policies, one per cell edge.

    Every edge owns a small two-layer MLP that maps the edge's index (fed as
    a 1x1 tensor) to ``num_ops`` logits. Because that input never changes,
    each MLP parameterises a single distribution over operations.

    Args:
        num_edges: number of edges, i.e. number of per-edge networks.
        num_ops: number of candidate operations per edge.
    """

    def __init__(self, num_edges=6, num_ops=5):
        super().__init__()
        self.num_edges = num_edges
        self.num_ops = num_ops

        # Create a policy network for each edge in the architecture
        self.edge_policies = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, 64),
                nn.ReLU(),
                nn.Linear(64, num_ops)
            ) for _ in range(num_edges)
        ])

        # Re-initialise only after every module exists; changing this order
        # would change which random numbers each layer receives for a seed.
        for edge_policy in self.edge_policies:
            for layer in edge_policy:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    nn.init.constant_(layer.bias, 0.0)

    def forward(self):
        """Return one ``(1, num_ops)`` probability tensor per edge."""
        probs = []
        for i, edge_policy in enumerate(self.edge_policies):
            # Build the input with the dtype/device of this edge's weights so
            # the module keeps working after ``.to(device)`` or ``.double()``.
            reference = edge_policy[0].weight
            edge_input = torch.tensor(
                [[i]], dtype=reference.dtype, device=reference.device
            )
            logits = edge_policy(edge_input)
            probs.append(nn.functional.softmax(logits, dim=1))
        return probs

    def sample_arch(self):
        """Sample one operation per edge from the current policy.

        Returns:
            A pair ``(actions, log_probs)``: ``actions`` is a list of Python
            ints (one per edge) and ``log_probs`` is the stacked tensor of
            the corresponding log-probabilities, still attached to the graph
            so it can be used in a policy-gradient loss.
        """
        actions = []
        log_probs = []

        for prob in self.forward():
            distribution = Categorical(prob)
            action = distribution.sample()
            actions.append(action.item())
            log_probs.append(distribution.log_prob(action))

        return actions, torch.stack(log_probs)


class ExponentialMovingAverage:
    """Exponentially weighted running mean, normalised by accumulated weight.

    Both the weighted sum and the sum of weights are tracked, and ``value``
    returns their ratio, so the first observation is reported as-is instead
    of being pulled toward the zero start value.
    """

    def __init__(self, momentum=0.9):
        self._momentum = momentum
        self._numerator = 0    # decayed sum of weighted observations
        self._denominator = 0  # decayed sum of the weights themselves

    def update(self, value):
        """Fold a new observation into the running average."""
        decay = self._momentum
        fresh_weight = 1 - decay
        self._numerator = decay * self._numerator + fresh_weight * value
        self._denominator = decay * self._denominator + fresh_weight

    def value(self):
        """Return the current average, or 0 while no weight has accumulated."""
        if self._denominator > 0:
            return self._numerator / self._denominator
        return 0


class ReinforceFinderNASBench201:
    """REINFORCE search for a high-accuracy NAS-Bench-201 architecture.

    Args:
        dataset: dataset name handed to the accuracy predictor.
        logger: logger used for progress messages; when ``None`` a module
            logger is used so the search still runs.
        **kwargs: optional settings read with defaults:
            ``metric`` ("test"), ``learning_rate`` (0.01),
            ``max_samples`` (100), ``ema_momentum`` (0.9), ``seed`` (0).
            Passing ``seed=None`` skips seeding.
    """

    def __init__(
        self,
        dataset='cifar10',
        logger=None,
        **kwargs
    ):
        self.dataset = dataset
        self.arch_manager = ArchManager()
        self.metric = kwargs.get("metric", "test")
        self.accuracy_predictor = AccuracyPredictor(dataset, self.metric)
        # The search methods log unconditionally, so never leave this None.
        self.logger = logger if logger is not None else logging.getLogger(__name__)

        self.learning_rate = kwargs.get("learning_rate", 0.01)
        self.max_samples = kwargs.get("max_samples", 100)
        self.ema_momentum = kwargs.get("ema_momentum", 0.9)
        self.seed = kwargs.get("seed", 0)

        # Seed BEFORE building the policy: its weight initialisation draws
        # from torch's RNG, so seeding afterwards would leave the starting
        # policy different on every run.
        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
            print(f"Random seed set to: {self.seed}")

        self.policy = ReinforcePolicy(
            num_edges=self.arch_manager.num_edges,
            num_ops=self.arch_manager.num_ops
        )
        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.learning_rate)
        # Baseline for variance reduction in policy gradient
        self.baseline = ExponentialMovingAverage(self.ema_momentum)

        # Cache of architecture string -> accuracy, to avoid repeated queries.
        self.arch_performances = {}

        self.total_explored = 0
        self.best_arch = None
        self.best_acc = 0.0
        self.best_found_at = 0

    def evaluate_arch(self, arch_str):
        """Return the accuracy of ``arch_str`` and update the best-so-far record.

        Accuracies are cached per architecture string. Every call, including
        cache hits, increments ``total_explored``.
        """
        if arch_str not in self.arch_performances:
            acc = self.accuracy_predictor.predict_accuracy(arch_str)
            self.arch_performances[arch_str] = acc
        else:
            acc = self.arch_performances[arch_str]

        self.total_explored += 1

        if acc > self.best_acc:
            self.best_acc = acc
            self.best_arch = arch_str
            self.best_found_at = self.total_explored
            self.logger.info(f"New best architecture found! Acc: {acc:.4f}, Explored: {self.total_explored}")
            self.logger.info(f"Architecture: {arch_str}")

        return acc

    def run_reinforce_search(self):
        """Sample, evaluate and update the policy until the budget is spent.

        Returns:
            A tuple ``(best_valids, [best_acc, best_arch], best_found_at)``
            where ``best_valids`` holds the best accuracy seen so far after
            each sample and ``best_found_at`` is the 1-based sample count at
            which the best architecture was found (0 if none was).
        """
        self.logger.info(f"Starting REINFORCE search for NAS-Bench-201 on {self.dataset} with {self.metric} metric...")
        self.logger.info(f"Learning rate: {self.learning_rate}, EMA momentum: {self.ema_momentum}")
        self.logger.info(f"Max samples: {self.max_samples}")

        best_valids = []

        # Context manager guarantees the bar is closed even if a step raises.
        with tqdm(total=self.max_samples) as pbar:
            while self.total_explored < self.max_samples:
                # Sample architecture from current policy
                actions, log_probs = self.policy.sample_arch()
                arch_str = self.arch_manager._ops_to_arch_str(actions)

                # Evaluate the sampled architecture
                reward = self.evaluate_arch(arch_str)

                # The baseline absorbs the current reward before the advantage
                # is taken, so the first sample's advantage is (about) zero.
                self.baseline.update(reward)
                baseline_value = self.baseline.value()

                # Calculate advantage and policy loss for REINFORCE update
                advantage = reward - baseline_value
                policy_loss = -(log_probs * advantage).sum()

                # Update policy network
                self.optimizer.zero_grad()
                policy_loss.backward()
                self.optimizer.step()

                best_valids.append(self.best_acc)

                if self.total_explored % 10 == 0 or self.total_explored == self.max_samples:
                    self.logger.info(f"\nSample {self.total_explored}/{self.max_samples}:")
                    self.logger.info(f"Current arch: {arch_str}")
                    self.logger.info(f"Accuracy: {reward:.4f}, Baseline: {baseline_value:.4f}")
                    self.logger.info(f"Policy loss: {policy_loss.item():.4f}")
                    self.logger.info(f"Overall best: {self.best_acc:.4f}")

                pbar.update(1)

        # Nothing is explored when max_samples <= 0; avoid dividing by zero.
        if self.total_explored:
            found_fraction = self.best_found_at / self.total_explored
        else:
            found_fraction = 0.0

        self.logger.info("\nREINFORCE search completed!")
        self.logger.info(f"Total explored architectures: {self.total_explored}")
        self.logger.info(f"Best architecture found at #{self.best_found_at} (after {found_fraction:.2%} of total):")
        self.logger.info(f"Accuracy: {self.best_acc:.4f}")
        self.logger.info(f"Architecture: {self.best_arch}")

        return best_valids, [self.best_acc, self.best_arch], self.best_found_at


def _build_arg_parser():
    """Create the command-line parser for the REINFORCE search script."""
    parser = argparse.ArgumentParser(description='Search for high-performance architectures in NAS-Bench-201 using REINFORCE')
    parser.add_argument('--samples', type=int, default=100, help='Maximum number of architectures to explore')
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'ImageNet16-120'],
                        help='Dataset to optimize architecture for')
    parser.add_argument('--metric', type=str, default='test', choices=['test', 'valid'],
                        help='Which metric to use for evaluation (test or validation accuracy)')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate for REINFORCE')
    parser.add_argument('--ema_momentum', type=float, default=0.9, help='Momentum for the baseline EMA')
    parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility')
    return parser


def _configure_logger(args):
    """Route INFO logs to a per-run file and the console; return the search logger."""
    current_file = os.path.basename(__file__).split('.')[0]
    log_dir = f"search_logs/{current_file}/{args.dataset}-{args.metric}"
    os.makedirs(log_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    log_name = f"seed{args.seed}_{args.dataset}_{args.metric}_lr{args.learning_rate}_s{args.samples}_{timestamp}.log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(log_dir, log_name)),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger("NAS-REINFORCE")


def main():
    """Parse CLI options, run the REINFORCE search and print the best result."""
    args = _build_arg_parser().parse_args()
    logger = _configure_logger(args)

    finder = ReinforceFinderNASBench201(
        dataset=args.dataset,
        logger=logger,
        max_samples=args.samples,
        learning_rate=args.learning_rate,
        ema_momentum=args.ema_momentum,
        seed=args.seed,
        metric=args.metric
    )

    print(f"Starting REINFORCE search on {args.dataset} with {args.metric} metric, max samples: {args.samples}")
    best_valids, best_info, best_found_at = finder.run_reinforce_search()

    # With a zero sample budget nothing is explored; avoid dividing by zero.
    if finder.total_explored:
        found_fraction = best_found_at / finder.total_explored
    else:
        found_fraction = 0.0

    print("\nBest architecture:")
    print(f"Best architecture found at sample: #{best_found_at} (after {found_fraction:.2%} of total)")
    print(f"Architecture string: {best_info[1]}")
    print(f"Accuracy on {args.dataset} ({args.metric}): {best_info[0]:.4f}")


# Run the command-line search only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 